{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1SnEBBpYsOhK"
      },
      "source": [
        "Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "97PCCHRrsi8k"
      },
      "outputs": [],
      "source": [
        "#@title Full license text\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "# \n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "# \n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tHpX8QhAhilt"
      },
      "source": [
        "# LSTMs in Haiku\n",
        "\n",
        "**[Haiku](https://github.com/deepmind/dm-haiku) is a simple neural network library for [JAX](https://github.com/jax-ml/jax).**\n",
        "\n",
        "This notebook walks through a simple LSTM in JAX with Haiku.\n",
        "\n",
        "For first-time Haiku users, we recommend that you first check out our [Quickstart](https://github.com/deepmind/dm-haiku#quickstart) and [MNIST example](https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tZQ5MRaBiSAC"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "JblB4wS7iJVa"
      },
      "outputs": [],
      "source": [
        "!pip install dm-haiku optax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1lGChhE1hPYW"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "from typing import Tuple, TypeVar\n",
        "import warnings\n",
        "\n",
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import plotnine as gg\n",
        "\n",
        "T = TypeVar('T')\n",
        "Pair = Tuple[T, T]\n",
        "\n",
        "gg.theme_set(gg.theme_bw())\n",
        "warnings.filterwarnings('ignore')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_gNhADeoiVAY"
      },
      "source": [
        "## Generating Data\n",
        "\n",
        "In this notebook, we generate many sine waves (of the same period), and try to predict the next value in the wave based on its previous values.\n",
        "\n",
        "For simplicity, we generate static-sized datasets and wrap them with an iterator-based API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "j7aLjpFKiT2s"
      },
      "outputs": [],
      "source": [
        "def sine_seq(\n",
        "    phase: np.ndarray,\n",
        "    seq_len: int,\n",
        "    samples_per_cycle: float,\n",
        ") -> Pair[np.ndarray]:\n",
        "  \"\"\"Samples sine waves and returns them as next-step (input, target) pairs.\n",
        "\n",
        "  Args:\n",
        "    phase: Phase offset(s) in radians. A scalar yields a single wave (B = 1);\n",
        "      a 1-D array of B phases yields one wave per phase.\n",
        "    seq_len: Number of time steps T in each returned sequence.\n",
        "    samples_per_cycle: Number of samples per full 2*pi period. It does not\n",
        "      have to be an integer.\n",
        "\n",
        "  Returns:\n",
        "    A pair (x, y) of [T, B] arrays in which y is x shifted one step into the\n",
        "    future, i.e. y[t] is the sample that follows x[t].\n",
        "  \"\"\"\n",
        "  # Sample seq_len + 1 points so that inputs and targets overlap by T - 1.\n",
        "  angular_step = 2 * math.pi / samples_per_cycle\n",
        "  t = np.arange(seq_len + 1) * angular_step\n",
        "  # [T + 1, 1] + [B] broadcasts to [T + 1, B].\n",
        "  sine_t = np.sin(t.reshape([-1, 1]) + phase)\n",
        "  return sine_t[:-1, :], sine_t[1:, :]\n",
        "\n",
        "\n",
        "def generate_data(\n",
        "    seq_len: int,\n",
        "    train_size: int,\n",
        "    valid_size: int,\n",
        ") -> Pair[Pair[np.ndarray]]:\n",
        "  \"\"\"Draws random-phase sine waves and splits them into train/valid sets.\n",
        "\n",
        "  Args:\n",
        "    seq_len: Number of time steps in every sequence.\n",
        "    train_size: Number of sequences placed in the training split.\n",
        "    valid_size: Number of sequences placed in the validation split.\n",
        "\n",
        "  Returns:\n",
        "    ((train_x, train_y), (valid_x, valid_y)), each array shaped\n",
        "    [seq_len, num_sequences, 1].\n",
        "  \"\"\"\n",
        "  num_waves = train_size + valid_size\n",
        "  phases = np.random.uniform(0., 2 * math.pi, [num_waves])\n",
        "  inputs, targets = sine_seq(phases, seq_len, 3 * seq_len / 4)\n",
        "\n",
        "  # Append a trailing feature axis: [T, B] -> [T, B, 1].\n",
        "  inputs = inputs[..., np.newaxis]\n",
        "  targets = targets[..., np.newaxis]\n",
        "\n",
        "  # The first train_size waves form the training split, the rest validation.\n",
        "  train_split = (inputs[:, :train_size], targets[:, :train_size])\n",
        "  valid_split = (inputs[:, train_size:], targets[:, train_size:])\n",
        "  return train_split, valid_split\n",
        "\n",
        "\n",
        "class Dataset:\n",
        "  \"\"\"An endless iterator over (x, y) arrays, batch_size examples at a time.\n",
        "\n",
        "  Both arrays are sliced along axis 1. Iteration never raises StopIteration:\n",
        "  after the last batch has been returned the cursor wraps back to the start.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, xy: Pair[np.ndarray], batch_size: int):\n",
        "    \"\"\"Wraps xy for batched iteration.\n",
        "\n",
        "    Args:\n",
        "      xy: Pair (x, y) of arrays that are batched along axis 1.\n",
        "      batch_size: Number of examples per batch. Must be positive and evenly\n",
        "        divide the dataset size, x.shape[1].\n",
        "\n",
        "    Raises:\n",
        "      ValueError: If batch_size is not positive or does not divide the\n",
        "        dataset size.\n",
        "    \"\"\"\n",
        "    self._x, self._y = xy\n",
        "    self._batch_size = batch_size\n",
        "    self._length = self._x.shape[1]\n",
        "    self._idx = 0\n",
        "    # Checked first: zero would crash the modulo below and a negative value\n",
        "    # would pass it while producing wrong slices in __next__.\n",
        "    if batch_size <= 0:\n",
        "      raise ValueError('batch_size must be positive, got {}.'.format(batch_size))\n",
        "    if self._length % batch_size != 0:\n",
        "      msg = 'dataset size {} must be divisible by batch_size {}.'\n",
        "      raise ValueError(msg.format(self._length, batch_size))\n",
        "\n",
        "  def __iter__(self) -> 'Dataset':\n",
        "    \"\"\"Returns self so the dataset works with iter() and for loops.\"\"\"\n",
        "    return self\n",
        "\n",
        "  def __next__(self) -> Pair[np.ndarray]:\n",
        "    \"\"\"Returns the next (x, y) batch, wrapping around at the end of the data.\"\"\"\n",
        "    start = self._idx\n",
        "    end = start + self._batch_size\n",
        "    x, y = self._x[:, start:end], self._y[:, start:end]\n",
        "    # The constructor guarantees the size is a multiple of batch_size, so the\n",
        "    # cursor lands exactly on the end of the data before wrapping back to 0.\n",
        "    self._idx = end % self._length\n",
        "    return x, y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Av3gipvVi_49"
      },
      "outputs": [],
      "source": [
        "TRAIN_SIZE = 2 ** 14\n",
        "VALID_SIZE = 128\n",
        "BATCH_SIZE = 8\n",
        "SEQ_LEN = 64\n",
        "DATA_SEED = 0\n",
        "\n",
        "# generate_data draws its phases from NumPy's global RNG; seed it so that\n",
        "# restarting the kernel and re-running reproduces the same dataset and plots.\n",
        "np.random.seed(DATA_SEED)\n",
        "train, valid = generate_data(SEQ_LEN, TRAIN_SIZE, VALID_SIZE)\n",
        "\n",
        "# Plot an observation/target pair.\n",
        "df = pd.DataFrame({'x': train[0][:, 0, 0], 'y': train[1][:, 0, 0]}).reset_index()\n",
        "df = pd.melt(df, id_vars=['index'], value_vars=['x', 'y'])\n",
        "plot = gg.ggplot(df) + gg.aes(x='index', y='value', color='variable') + gg.geom_line()\n",
        "plot.draw()\n",
        "\n",
        "train_ds = Dataset(train, BATCH_SIZE)\n",
        "valid_ds = Dataset(valid, BATCH_SIZE)\n",
        "del train, valid  # Don't leak temporaries."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LZGw5Jdvjmqh"
      },
      "source": [
        "## Training an LSTM\n",
        "\n",
        "To train the LSTM, we define a Haiku function which unrolls the LSTM over the input sequence, generating predictions for all output values. The LSTM always starts with its initial state at the start of the sequence.\n",
        "\n",
        "The Haiku function is then transformed into a pure function through `hk.transform`, and is trained with Adam on an L2 prediction loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "nacnTj5ejIK5"
      },
      "outputs": [],
      "source": [
        "def unroll_net(seqs: jax.Array):\n",
        "  \"\"\"Runs an LSTM over a batch of sequences and maps each step to a scalar.\n",
        "\n",
        "  Args:\n",
        "    seqs: Time-major input of shape [T, B, F].\n",
        "\n",
        "  Returns:\n",
        "    A tuple of the per-step outputs (one scalar feature per time step and\n",
        "    batch element) and the final LSTM state.\n",
        "  \"\"\"\n",
        "  lstm = hk.LSTM(32)\n",
        "  initial_state = lstm.initial_state(seqs.shape[1])\n",
        "  hidden_seq, final_state = hk.dynamic_unroll(lstm, seqs, initial_state)\n",
        "  # The readout could live inside the recurrent core, but on modern\n",
        "  # accelerators it is cheaper to apply it once to the whole sequence than\n",
        "  # once per sequence element.\n",
        "  readout = hk.BatchApply(hk.Linear(1))\n",
        "  return readout(hidden_seq), final_state\n",
        "\n",
        "\n",
        "model = hk.transform(unroll_net)\n",
        "\n",
        "\n",
        "def train_model(\n",
        "    train_ds: Dataset,\n",
        "    valid_ds: Dataset,\n",
        "    num_steps: int = 2001,\n",
        "    learning_rate: float = 1e-3,\n",
        "    seed: int = 428,\n",
        "    log_every: int = 100,\n",
        ") -> hk.Params:\n",
        "  \"\"\"Initializes and trains a model on train_ds, returning the final params.\n",
        "\n",
        "  Args:\n",
        "    train_ds: Iterator of (x, y) training batches. One batch is consumed up\n",
        "      front to initialize the parameters and one more per training step.\n",
        "    valid_ds: Iterator of (x, y) validation batches. One batch is consumed\n",
        "      each time the losses are logged.\n",
        "    num_steps: Number of optimizer steps to run.\n",
        "    learning_rate: Step size passed to Adam.\n",
        "    seed: Seed of the PRNG key used to initialize the parameters.\n",
        "    log_every: The validation and training losses are printed every\n",
        "      log_every steps, starting with step 0.\n",
        "\n",
        "  Returns:\n",
        "    The parameters after the final update.\n",
        "  \"\"\"\n",
        "  rng = jax.random.PRNGKey(seed)\n",
        "  opt = optax.adam(learning_rate)\n",
        "\n",
        "  @jax.jit\n",
        "  def loss(params, x, y):\n",
        "    # Mean squared error between the prediction and the next-step target.\n",
        "    pred, _ = model.apply(params, None, x)\n",
        "    return jnp.mean(jnp.square(pred - y))\n",
        "\n",
        "  @jax.jit\n",
        "  def update(params, opt_state, x, y):\n",
        "    l, grads = jax.value_and_grad(loss)(params, x, y)\n",
        "    updates, opt_state = opt.update(grads, opt_state)\n",
        "    params = optax.apply_updates(params, updates)\n",
        "    return l, params, opt_state\n",
        "\n",
        "  # Initialize state. The batch is only used for its shape and dtype.\n",
        "  sample_x, _ = next(train_ds)\n",
        "  params = model.init(rng, sample_x)\n",
        "  opt_state = opt.init(params)\n",
        "\n",
        "  for step in range(num_steps):\n",
        "    should_log = step % log_every == 0\n",
        "    if should_log:\n",
        "      # The validation loss is measured before this step's update is applied.\n",
        "      x, y = next(valid_ds)\n",
        "      print(\"Step {}: valid loss {}\".format(step, loss(params, x, y)))\n",
        "\n",
        "    x, y = next(train_ds)\n",
        "    train_loss, params, opt_state = update(params, opt_state, x, y)\n",
        "    if should_log:\n",
        "      print(\"Step {}: train loss {}\".format(step, train_loss))\n",
        "\n",
        "  return params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "AssgDctokbl5"
      },
      "outputs": [],
      "source": [
        "trained_params = train_model(train_ds, valid_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yr7jrOL3ki-b"
      },
      "source": [
        "## Sampling\n",
        "\n",
        "The point of training models is so that they can make predictions! How can we generate predictions with the trained model?\n",
        "\n",
        "If we're allowed to feed in the ground truth, we can just run the original model's `apply` function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "f2qETEqXLT1N"
      },
      "outputs": [],
      "source": [
        "def plot_samples(truth: np.ndarray, prediction: np.ndarray) -> gg.ggplot:\n",
        "  \"\"\"Builds a line plot overlaying a true sequence and a predicted one.\n",
        "\n",
        "  Args:\n",
        "    truth: Ground-truth values; must have the same shape as prediction.\n",
        "    prediction: Model outputs to compare against truth.\n",
        "\n",
        "  Returns:\n",
        "    A plot with one line per series, drawn against the position in the\n",
        "    sequence.\n",
        "  \"\"\"\n",
        "  assert truth.shape == prediction.shape\n",
        "  series = {'truth': truth.squeeze(), 'predicted': prediction.squeeze()}\n",
        "  wide = pd.DataFrame(series).reset_index()\n",
        "  # Reshape to long form so that the lines can be coloured by series name.\n",
        "  long = pd.melt(wide, id_vars=['index'], value_vars=['truth', 'predicted'])\n",
        "  mapping = gg.aes(x='index', y='value', color='variable')\n",
        "  return gg.ggplot(long) + mapping + gg.geom_line()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "KOuK1egilGD0"
      },
      "outputs": [],
      "source": [
        "# Grab a sample from the validation set.\n",
        "sample_x, _ = next(valid_ds)\n",
        "sample_x = sample_x[:, :1]  # Shrink to batch-size 1.\n",
        "\n",
        "# Generate a prediction, feeding in ground truth at each point as input.\n",
        "predicted, _ = model.apply(trained_params, None, sample_x)\n",
        "\n",
        "# predicted[t] is the model's estimate of sample_x[t + 1], so drop the first\n",
        "# input and the last prediction to line the two series up.\n",
        "plot = plot_samples(sample_x[1:], predicted[:-1])\n",
        "plot.draw()\n",
        "del sample_x, predicted\n",
        "\n",
        "# Typically, the beginning of the prediction is a bit wonky, but the curve\n",
        "# quickly smooths out."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tDyGshz_lwrM"
      },
      "source": [
        "If we can't feed in the ground truth (because we don't have it), we can also run the model autoregressively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Cg8oQ75Ulvld"
      },
      "outputs": [],
      "source": [
        "def autoregressive_predict(\n",
        "    trained_params: hk.Params,\n",
        "    context: jax.Array,\n",
        "    seq_len: int,\n",
        "):\n",
        "  \"\"\"Given a context, autoregressively generate the rest of a sine wave.\n",
        "\n",
        "  Every step re-runs the whole model on the context plus all predictions made\n",
        "  so far and feeds the newest prediction back in as the next input.\n",
        "\n",
        "  Args:\n",
        "    trained_params: Parameters for `model`.\n",
        "    context: Observed start of the sequence, with time as the leading axis.\n",
        "    seq_len: Total sequence length to generate up to. Must be greater than\n",
        "      the context length, context.shape[0].\n",
        "\n",
        "  Returns:\n",
        "    The model outputs of the final unroll. That unroll sees seq_len - 1\n",
        "    inputs, so the result has seq_len - 1 entries along the time axis.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If seq_len is not greater than the context length, in which\n",
        "      case there would be nothing to generate.\n",
        "  \"\"\"\n",
        "  num_steps = seq_len - context.shape[0]\n",
        "  if num_steps <= 0:\n",
        "    raise ValueError(\n",
        "        'seq_len ({}) must be greater than the context length ({}).'.format(\n",
        "            seq_len, context.shape[0]))\n",
        "  # Build the jitted function once instead of on every loop iteration. The\n",
        "  # input grows by one element per step, so each call still sees a new shape.\n",
        "  apply_fn = jax.jit(model.apply)\n",
        "  ar_outs = []\n",
        "  context = jax.device_put(context)\n",
        "  for _ in range(num_steps):\n",
        "    full_context = jnp.concatenate([context] + ar_outs)\n",
        "    outs, _ = apply_fn(trained_params, None, full_context)\n",
        "    # Append the newest prediction to ar_outs.\n",
        "    ar_outs.append(outs[-1:])\n",
        "  # Return the final full prediction.\n",
        "  return outs\n",
        "\n",
        "\n",
        "sample_x, _ = next(valid_ds)\n",
        "context_length = SEQ_LEN // 8\n",
        "# Cut the batch-size 1 context from the start of the sequence.\n",
        "context = sample_x[:context_length, :1]\n",
        "\n",
        "# We can reuse params we got from training for inference - as long as the\n",
        "# declaration order is the same.\n",
        "predicted = autoregressive_predict(trained_params, context, SEQ_LEN)\n",
        "\n",
        "# The final unroll inside autoregressive_predict sees SEQ_LEN - 1 inputs, so\n",
        "# `predicted` already has the same length as the shifted targets sample_x[1:].\n",
        "plot = plot_samples(sample_x[1:, :1], predicted)\n",
        "# Mark where the ground-truth context ends and free-running generation starts.\n",
        "plot += gg.geom_vline(xintercept=len(context), linetype='dashed')\n",
        "plot.draw()\n",
        "# sample_x and context are kept on purpose: the following cells reuse them.\n",
        "del predicted"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qGkr2gf2oALo"
      },
      "source": [
        "### Sharing parameters with a different function.\n",
        "\n",
        "Unfortunately, this is a bit slow - we're doing O(N^2) computation for a sequence of length N.\n",
        "\n",
        "It'd be better if we could do the autoregressive sampling all at once - but we need to write a new Haiku function for that.\n",
        "\n",
        "We're in luck - if the Haiku module names match, the same parameters can be used for multiple Haiku functions.\n",
        "\n",
        "This can be achieved through a combination of two techniques:\n",
        "\n",
        "1. If we manually give a unique name to a module, we can ensure that the parameters are directed to the right places.\n",
        "2. If modules are instantiated in the same order, they'll have the same names in different functions.\n",
        "\n",
        "Here, we rely on method #2 to create a fast autoregressive prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WdKcHr6_n_ba"
      },
      "outputs": [],
      "source": [
        "def fast_autoregressive_predict_fn(context, seq_len):\n",
        "  \"\"\"Given a context, autoregressively generate the rest of a sine wave.\n",
        "\n",
        "  The recurrent state is carried from one step to the next, so each generated\n",
        "  element costs a single LSTM step instead of a full re-unroll.\n",
        "\n",
        "  Args:\n",
        "    context: Observed start of the sequence, with time as the leading axis\n",
        "      and the batch as the second axis.\n",
        "    seq_len: Total number of outputs to return. It sets the length of a\n",
        "      Python loop, so it must be a concrete int, and it must be at least the\n",
        "      context length, context.shape[0].\n",
        "\n",
        "  Returns:\n",
        "    The outputs for the context followed by the generated outputs, seq_len\n",
        "    entries along the time axis in total.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If seq_len is smaller than the context length.\n",
        "  \"\"\"\n",
        "  num_steps = seq_len - context.shape[0]\n",
        "  if num_steps < 0:\n",
        "    raise ValueError(\n",
        "        'seq_len ({}) must be at least the context length ({}).'.format(\n",
        "            seq_len, context.shape[0]))\n",
        "  # Same modules, created in the same order, as in `unroll_net`.\n",
        "  core = hk.LSTM(32)\n",
        "  dense = hk.Linear(1)\n",
        "  state = core.initial_state(context.shape[1])\n",
        "  # Unroll over the context using `hk.dynamic_unroll`.\n",
        "  # As before, we `hk.BatchApply` the Linear for efficiency.\n",
        "  context_outs, state = hk.dynamic_unroll(core, context, state)\n",
        "  context_outs = hk.BatchApply(dense)(context_outs)\n",
        "  if num_steps == 0:\n",
        "    # Nothing to generate, and jnp.stack cannot stack an empty list.\n",
        "    return context_outs\n",
        "\n",
        "  # Now, unroll one step at a time using the running recurrent state.\n",
        "  ar_outs = []\n",
        "  x = context_outs[-1]\n",
        "  for _ in range(num_steps):\n",
        "    x, state = core(x, state)\n",
        "    x = dense(x)\n",
        "    ar_outs.append(x)\n",
        "  return jnp.concatenate([context_outs, jnp.stack(ar_outs)])\n",
        "\n",
        "\n",
        "fast_ar_predict = hk.transform(fast_autoregressive_predict_fn)\n",
        "# seq_len (positional argument 3 of `apply`) sets the length of a Python loop,\n",
        "# so it is marked static rather than traced.\n",
        "fast_ar_predict = jax.jit(fast_ar_predict.apply, static_argnums=3)\n",
        "# Reuse the same context from the previous cell.\n",
        "predicted = fast_ar_predict(trained_params, None, context, SEQ_LEN)\n",
        "# The plots should be equivalent!\n",
        "# This function returns SEQ_LEN outputs, one more than there are shifted\n",
        "# targets in sample_x[1:], so the last output is dropped.\n",
        "plot = plot_samples(sample_x[1:, :1], predicted[:-1])\n",
        "plot += gg.geom_vline(xintercept=len(context), linetype='dashed')\n",
        "plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9S0tkPXGrU3a"
      },
      "outputs": [],
      "source": [
        "%timeit autoregressive_predict(trained_params, context, SEQ_LEN)\n",
        "%timeit fast_ar_predict(trained_params, None, context, SEQ_LEN)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "haiku-lstms.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
